import dataclasses
import io
import typing as T
from pathlib import Path

import numpy as np
import pydub
import streamlit as st
from PIL import Image

from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def _power_scaled_alphas(num_steps: int, alpha_power: float) -> np.ndarray:
    """
    Compute the interpolation alphas with a power curve applied.

    Evenly spaced values in [0, 1] are remapped to [-1, 1], their magnitude is raised
    to `alpha_power` while the sign is kept, and the result is mapped back. A power of
    1.0 keeps the spacing linear; for any non-negative power the output stays within
    [0, 1]. Powers above 1.0 bunch the alphas around 0.5, powers below 1.0 push them
    toward the endpoints.
    """
    centered = np.linspace(0, 1, num_steps) * 2 - 1
    return (np.abs(centered) ** alpha_power * np.sign(centered) + 1) / 2


def render() -> None:
    """
    Render the prompt interpolation page.

    Reads generation settings from the sidebar and the two prompts from a form, runs
    one model generation per interpolation alpha via run_interpolation, then
    concatenates the resulting audio clips into a single track that can be played and
    downloaded. Returns early with an info message when either prompt is empty or
    when a custom seed image was selected but not uploaded.
    """
    st.subheader("🎭 Interpolation")
    st.write(
        """
    Interpolate between prompts in the latent space.
    """
    )

    with st.expander("Help", False):
        st.write(
            """
            This tool allows specifying two endpoints and generating a long-form interpolation
            between them that traverses the latent space. The interpolation is generated by
            the method described at https://www.riffusion.com/about. A seed image is used to
            set the beat and tempo of the generated audio, and can be set in the sidebar.
            Usually the seed is changed or the prompt, but not both at once. You can browse
            infinite variations of the same prompt by changing the seed.

            For example, try going from "church bells" to "jazz" with 10 steps and 0.75 denoising.
            This will generate a 50 second clip at 5 seconds per step. Then play with the seeds
            or denoising to get different variations.
            """
        )

    # Sidebar params

    device = streamlit_util.select_device(st.sidebar)
    extension = streamlit_util.select_audio_extension(st.sidebar)

    num_interpolation_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Interpolation steps",
            value=12,
            min_value=1,
            max_value=20,
            help="Number of model generations between the two prompts. Controls the duration.",
        ),
    )

    num_inference_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Steps per sample", value=50, help="Number of denoising steps per model run"
        ),
    )

    guidance = st.sidebar.number_input(
        "Guidance",
        value=7.0,
        help="How much the model listens to the text prompt",
    )

    init_image_name = st.sidebar.selectbox(
        "Seed image",
        # TODO(hayk): Read from directory
        options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
        index=0,
        help="Which seed image to use for img2img. Custom allows uploading your own.",
    )
    assert init_image_name is not None

    # Always bound, so the later "custom" check never touches an undefined name
    init_image_file: T.Any = None
    if init_image_name == "custom":
        init_image_file = st.sidebar.file_uploader(
            "Upload a custom seed image",
            type=streamlit_util.IMAGE_EXTENSIONS,
            label_visibility="collapsed",
        )
        if init_image_file:
            st.sidebar.image(init_image_file)

    # A negative power would turn the centre alpha into NaN (0 ** negative is inf, times
    # sign 0) and push the remaining alphas outside [0, 1], so the input is bounded at 0.
    alpha_power = st.sidebar.number_input("Alpha Power", value=1.0, min_value=0.0)

    show_individual_outputs = st.sidebar.checkbox(
        "Show individual outputs",
        value=False,
        help="Show each model output",
    )
    show_images = st.sidebar.checkbox(
        "Show individual images",
        value=False,
        help="Show each generated image",
    )

    # Apply power scaling to alphas to customize the interpolation curve
    alphas = _power_scaled_alphas(num_interpolation_steps, alpha_power)

    alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
    st.write(f"**Alphas** : [{alphas_str}]")

    # Prompt inputs A and B in two columns

    with st.form(key="interpolation_form"):
        left, right = st.columns(2)

        with left:
            st.write("##### Prompt A")
            prompt_input_a = PromptInput(
                guidance=guidance, **get_prompt_inputs(key="a", denoising_default=0.75)
            )

        with right:
            st.write("##### Prompt B")
            prompt_input_b = PromptInput(
                guidance=guidance, **get_prompt_inputs(key="b", denoising_default=0.75)
            )

        st.form_submit_button("Generate", type="primary")

    if not prompt_input_a.prompt or not prompt_input_b.prompt:
        st.info("Enter both prompts to interpolate between them")
        return

    if init_image_name == "custom":
        if not init_image_file:
            st.info("Upload a custom seed image")
            return
        init_image = Image.open(init_image_file).convert("RGB")
    else:
        init_image_path = (
            Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
        )
        init_image = Image.open(str(init_image_path)).convert("RGB")

    # TODO(hayk): Move this code into a shared place and add to riffusion.cli
    audio_bytes_list: T.List[io.BytesIO] = []
    for i, alpha in enumerate(alphas):
        # NOTE(review): seed_image_id stays "og_beat" whatever is picked in the sidebar;
        # the selected image itself is handed to run_interpolation as init_image.
        # Presumably the id is not consulted on this path - confirm.
        inputs = InferenceInput(
            alpha=float(alpha),
            num_inference_steps=num_inference_steps,
            seed_image_id="og_beat",
            start=prompt_input_a,
            end=prompt_input_b,
        )

        if i == 0:
            with st.expander("Example input JSON", expanded=False):
                st.json(dataclasses.asdict(inputs))

        image, audio_bytes = run_interpolation(
            inputs=inputs,
            init_image=init_image,
            device=device,
            extension=extension,
        )

        if show_individual_outputs:
            st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
            if show_images:
                st.image(image)
            st.audio(audio_bytes)

        # Only the audio is needed afterwards; the images are not retained
        audio_bytes_list.append(audio_bytes)

    st.write("#### Final Output")

    # TODO(hayk): Concatenate with overlap and better blending like in audio to audio
    audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
    concat_segment = audio_segments[0]
    for segment in audio_segments[1:]:
        concat_segment = concat_segment.append(segment, crossfade=0)

    final_audio = io.BytesIO()
    concat_segment.export(final_audio, format=extension)
    # Rewind so the player and the download button read from the start of the buffer
    final_audio.seek(0)

    st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
    st.audio(final_audio)

    output_name = (
        f"{prompt_input_a.prompt.replace(' ', '_')}_"
        f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
    )
    st.download_button(
        output_name,
        data=final_audio,
        file_name=output_name,
        mime=f"audio/{extension}",
    )


def get_prompt_inputs(
    key: str,
    include_negative_prompt: bool = False,
    cols: bool = False,
    denoising_default: float = 0.5,
) -> T.Dict[str, T.Any]:
    """
    Draw the prompt widgets and return their current values as a dict.

    The result always holds "prompt", "seed" and "denoising", plus "negative_prompt"
    when include_negative_prompt is set. The key argument is used as a suffix on every
    widget key so several prompt groups can be rendered together. With cols enabled
    the widgets are spread over two columns and the prompt label is shown; otherwise
    they are all drawn in the current container and the prompt label is collapsed.
    """
    first: T.Any
    second: T.Any
    if cols:
        first, second = st.columns(2)
        prompt_label_visibility = "visible"
    else:
        # No columns requested: both "sides" draw straight into the current container
        first = second = st
        prompt_label_visibility = "collapsed"

    values: T.Dict[str, T.Any] = {
        "prompt": first.text_input(
            "Prompt",
            label_visibility=prompt_label_visibility,
            key=f"prompt_{key}",
        ),
    }

    if include_negative_prompt:
        values["negative_prompt"] = second.text_input(
            "Negative Prompt",
            key=f"negative_prompt_{key}",
        )

    seed_value = first.number_input(
        "Seed",
        value=42,
        key=f"seed_{key}",
        help="Integer used to generate a random result. Vary this to explore alternatives.",
    )
    values["seed"] = T.cast(int, seed_value)

    values["denoising"] = second.number_input(
        "Denoising",
        value=denoising_default,
        key=f"denoising_{key}",
        help="How much to modify the seed image",
    )

    return values


@st.cache_data
def run_interpolation(
    inputs: InferenceInput,
    init_image: Image.Image,
    checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
    device: str = "cuda",
    extension: str = "mp3",
) -> T.Tuple[Image.Image, io.BytesIO]:
    """
    Run a single riffusion generation and convert its image to audio.

    Wrapped in st.cache_data. Returns the image produced by the pipeline together with
    the audio reconstructed from it, encoded in the format named by extension.
    """
    # The unet is loaded untraced so that the image width can vary
    model = streamlit_util.load_riffusion_checkpoint(
        checkpoint=checkpoint,
        device=device,
        no_traced_unet=True,
    )

    generated_image = model.riffuse(inputs, init_image=init_image, mask_image=None)

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    spectrogram_params = SpectrogramParams(min_frequency=0, max_frequency=10000)

    # Turn the generated spectrogram image back into audio
    encoded_audio = streamlit_util.audio_bytes_from_spectrogram_image(
        image=generated_image,
        params=spectrogram_params,
        device=device,
        output_format=extension,
    )

    return generated_image, encoded_audio
